{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyMPin07iIA2oewCCP9ZTz6w",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/vinthony/video-retalking/blob/main/quick_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## VideoReTalking：Audio-based Lip Synchronization for Talking Head Video Editing In the Wild\n",
        "\n",
        "[Arxiv](https://arxiv.org/abs/2211.14758) | [Project](https://vinthony.github.io/video-retalking/) | [Github](https://github.com/vinthony/video-retalking)\n",
        "\n",
        "Kun Cheng, Xiaodong Cun, Yong Zhang, Menghan Xia, Fei Yin, Mingrui Zhu, Xuan Wang, Jue Wang, Nannan Wang\n",
        "\n",
        "Xidian University, Tencent AI Lab, Tsinghua University\n",
        "\n",
        "*SIGGRAPH Asia 2022 Conferenence Track*\n",
        "\n"
      ],
      "metadata": {
        "id": "NVfkv2BXSpr3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Installation** (30s)"
      ],
      "metadata": {
        "id": "u9hdPaH6UL_F"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PnKT9goiQ3Hk"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
        "!nvidia-smi\n",
        "\n",
        "!python --version  \n",
        "!apt-get update\n",
        "!apt install ffmpeg &> /dev/null \n",
        "\n",
        "print('Git clone project and install requirements...')\n",
        "!git clone https://github.com/vinthony/video-retalking.git &> /dev/null\n",
        "%cd video-retalking\n",
        "# !pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n",
        "!pip install -r requirements.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Download Pretrained Models**"
      ],
      "metadata": {
        "id": "uwJS0eaM61Cq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title\n",
        "print('Download pre-trained models...')\n",
        "!mkdir ./checkpoints  \n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/30_net_gen.pth -O ./checkpoints/30_net_gen.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/BFM.zip -O ./checkpoints/BFM.zip\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/DNet.pt -O ./checkpoints/DNet.pt\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/ENet.pth -O ./checkpoints/ENet.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/expression.mat -O ./checkpoints/expression.mat\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/face3d_pretrain_epoch_20.pth -O ./checkpoints/face3d_pretrain_epoch_20.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/GFPGANv1.3.pth -O ./checkpoints/GFPGANv1.3.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/GPEN-BFR-512.pth -O ./checkpoints/GPEN-BFR-512.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/LNet.pth -O ./checkpoints/LNet.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/ParseNet-latest.pth -O ./checkpoints/ParseNet-latest.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/RetinaFace-R50.pth -O ./checkpoints/RetinaFace-R50.pth\n",
        "!wget https://github.com/vinthony/video-retalking/releases/download/v0.0.1/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat\n",
        "!unzip -d ./checkpoints/BFM ./checkpoints/BFM.zip"
      ],
      "metadata": {
        "id": "x18qYuQY678E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Inference**\n",
        "\n",
        "`--face`: Input video.\n",
        "\n",
        "`--audio`: Input audio. Both *.wav* and *.mp4* files are supported.\n",
        "\n",
        "You can choose our provided data from `./examples` folder or upload from your local computer.\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "QJRTF4U8UOjv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title\n",
        "import glob, os, sys\n",
        "import ipywidgets as widgets\n",
        "from IPython.display import HTML\n",
        "from base64 import b64encode\n",
        "\n",
        "print(\"Choose the Video name to edit: (saved in folder 'examples/face')\")\n",
        "vid_list = glob.glob1('examples/face/', '*.mp4')\n",
        "vid_list.sort()\n",
        "default_vid_name = widgets.Dropdown(options=vid_list, value='1.mp4')\n",
        "display(default_vid_name)\n",
        "\n",
        "print(\"Choose the Audio name to edit: (saved in folder 'examples/audio')\")\n",
        "audio_list = glob.glob1('examples/audio/', '*')\n",
        "audio_list.sort()\n",
        "default_audio_name = widgets.Dropdown(options=audio_list, value='1.wav')\n",
        "display(default_audio_name)\n"
      ],
      "metadata": {
        "id": "U-IY-cBSporP",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Visualize the input video and audio:"
      ],
      "metadata": {
        "id": "-MtI_R1bLJ-f"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title\n",
        "input_video_name = './examples/face/{}'.format(default_vid_name.value)\n",
        "input_video_mp4 = open('{}'.format(input_video_name),'rb').read()\n",
        "input_video_data_url = \"data:video/x-m4v;base64,\" + b64encode(input_video_mp4).decode()\n",
        "print('Display input video: {}'.format(input_video_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video width=400 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % input_video_data_url))\n",
        "\n",
        "input_audio_name = './examples/audio/{}'.format(default_audio_name.value)\n",
        "input_audio_mp4 = open('{}'.format(input_audio_name),'rb').read()\n",
        "input_audio_data_url = \"data:audio/wav;base64,\" + b64encode(input_audio_mp4).decode()\n",
        "print('Display input audio: {}'.format(input_audio_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <audio width=400 controls>\n",
        "        <source src=\"%s\" type=\"audio/wav\">\n",
        "  </audio>\n",
        "  \"\"\" % input_audio_data_url))\n"
      ],
      "metadata": {
        "id": "ljbScdofJyGO",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "input_video_path = 'examples/face/{}'.format(default_vid_name.value)\n",
        "input_audio_path = 'examples/audio/{}'.format(default_audio_name.value)\n",
        "\n",
        "!python3 inference.py \\\n",
        "  --face {input_video_path} \\\n",
        "  --audio {input_audio_path} \\\n",
        "  --outfile results/output.mp4"
      ],
      "metadata": {
        "id": "D7hUwRCyUYEA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Visualize the output video:"
      ],
      "metadata": {
        "id": "JB5RbKc-njkB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title\n",
        "# visualize code from makeittalk\n",
        "from IPython.display import HTML\n",
        "from base64 import b64encode\n",
        "import os, sys, glob, cv2, subprocess, platform\n",
        "\n",
        "def read_video(vid_name):\n",
        "  video_stream = cv2.VideoCapture(vid_name)\n",
        "  fps = video_stream.get(cv2.CAP_PROP_FPS)\n",
        "  full_frames = []\n",
        "  while True:\n",
        "    still_reading, frame = video_stream.read()\n",
        "    if not still_reading:\n",
        "        video_stream.release()\n",
        "        break\n",
        "    full_frames.append(frame)\n",
        "  return full_frames, fps\n",
        "\n",
        "input_video_frames, fps = read_video(input_video_path)\n",
        "output_video_frames, _ = read_video('./results/output.mp4')\n",
        "\n",
        "frame_h, frame_w = input_video_frames[0].shape[:-1]\n",
        "out_concat = cv2.VideoWriter('./temp/temp/result_concat.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w*2, frame_h))\n",
        "for i in range(len(output_video_frames)):\n",
        "  frame_input = input_video_frames[i % len(input_video_frames)]\n",
        "  frame_output = output_video_frames[i]\n",
        "  out_concat.write(cv2.hconcat([frame_input, frame_output]))\n",
        "out_concat.release()\n",
        "\n",
        "command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(input_audio_path, './temp/temp/result_concat.mp4', './results/output_concat_input.mp4')\n",
        "subprocess.call(command, shell=platform.system() != 'Windows')\n",
        "\n",
        "\n",
        "output_video_name = './results/output.mp4'\n",
        "output_video_mp4 = open('{}'.format(output_video_name),'rb').read()\n",
        "output_video_data_url = \"data:video/mp4;base64,\" + b64encode(output_video_mp4).decode()\n",
        "print('Display lip-syncing video: {}'.format(output_video_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video height=400 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % output_video_data_url))\n",
        "\n",
        "output_concat_video_name = './results/output_concat_input.mp4'\n",
        "output_concat_video_mp4 = open('{}'.format(output_concat_video_name),'rb').read()\n",
        "output_concat_video_data_url = \"data:video/mp4;base64,\" + b64encode(output_concat_video_mp4).decode()\n",
        "print('Display input video and lip-syncing video: {}'.format(output_concat_video_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video height=400 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % output_concat_video_data_url))\n"
      ],
      "metadata": {
        "id": "ravs9UDucMfy",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}